
import os
import numpy as np
import torch
import pyiqa
from transformers import BlipForQuestionAnswering
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
from transformers import ViTModel, ViTFeatureExtractor, SwinModel,ViTMSNModel,AutoFeatureExtractor
from pytorch_fid.inception import InceptionV3
from pytorch_fid.fid_score import get_activations
from tqdm import tqdm
import clip
from torchvision import models, transforms
from PIL import Image
from collections import Counter
import random
from sklearn.utils import resample
from sklearn.preprocessing import MinMaxScaler
import itertools
from utils_dqa import compute_mmd

def process_images(image_paths, model, preprocess, device, batch_size=32, 
                   moco=False, vit=False, 
                   inception=False, clip=False,swin=False,msn=False,
                   resnet=False, dino=False, use_clip=False):
    """Extract one embedding per image by running ``model`` over ``image_paths`` in batches.

    The backbone family is selected with the boolean flags. When several are
    set, precedence is inception > moco > vit > swin > msn > clip > default.

    Args:
        image_paths: Sequence of image file paths.
        model: Backbone to run; inputs are moved to ``device``.
        preprocess: Per-image transform (torchvision / CLIP style) for the
            moco / clip / default paths, or a Hugging Face image processor for
            vit / swin / msn. Unused for ``inception``.
        device: Torch device the inputs are moved to.
        batch_size: Number of images per forward pass.
        moco: Plain ``model(batch)`` on stacked ``preprocess`` outputs.
        vit: Hugging Face ViT; the token at position 0 ([CLS]) of
            ``last_hidden_state`` is used.
        inception: pytorch_fid InceptionV3; 2048-d activations come from
            ``get_activations``.
        clip, use_clip: OpenAI CLIP; ``model.encode_image`` is used.
            ``use_clip`` is an alias for ``clip``.
        swin: Hugging Face Swin; ``pooler_output`` is used.
        msn: Hugging Face ViT-MSN; the token at position 0 of
            ``last_hidden_state`` is used.
        resnet, dino: Accepted so call sites can name the backbone. Both use
            the default ``model(batch)`` path.

    Returns:
        np.ndarray with the per-batch feature arrays concatenated along axis 0.

    Raises:
        ValueError: from ``np.concatenate`` when ``image_paths`` is empty.
    """
    use_clip = clip or use_clip

    def load_rgb(path):
        # Decode fully, then release the file handle.
        with Image.open(path) as img:
            return img.convert('RGB')

    def stack(images):
        # Apply the per-image transform and batch the results on ``device``.
        return torch.stack([preprocess(img) for img in images]).to(device)

    embeddings = []
    for start in tqdm(range(0, len(image_paths), batch_size)):
        batch_paths = image_paths[start:start + batch_size]
        with torch.no_grad():
            if inception:
                # pytorch_fid reads the files itself, so nothing is decoded here.
                embeddings.append(get_activations(batch_paths, model, batch_size, 2048, device))
                continue
            images = [load_rgb(p) for p in batch_paths]
            if moco:
                features = model(stack(images))
            elif vit or swin:
                inputs = preprocess(images=images, return_tensors="pt").to(device)
                outputs = model(inputs['pixel_values'])
                if vit:
                    # ViT prepends a [CLS] token at position 0.
                    features = outputs.last_hidden_state[:, 0, :]
                else:
                    # Swin has no [CLS] token. Position 0 is only the first
                    # patch token, so use the pooled representation instead.
                    features = outputs.pooler_output
            elif msn:
                inputs = preprocess(images=images, return_tensors="pt").to(device)
                features = model(**inputs).last_hidden_state[:, 0, :]
            elif use_clip:
                features = model.encode_image(stack(images))
            else:
                features = model(stack(images))
            embeddings.append(features.detach().cpu().numpy())

    return np.concatenate(embeddings, axis=0)


class BiasEvaluator():
    def __init__(self, args, device):
        """Create the results cache directory and load the models for one IQA stage.

        Args:
            args: Run configuration. Reads ``args.root_dir``, the parent of the
                ``sdxl_embedding`` cache directory, and ``args.iqa``:
                '1' loads the embedding backbones, '2' the VQA models and '3'
                the pyiqa metrics. Any other value loads no models.
            device: Torch device for the stage-1 and stage-3 models.
        """
        self.args = args
        self.device = device

        # All cached per-group results are written to / read from here.
        self.sdxl_dir = os.path.join(args.root_dir, 'sdxl_embedding')
        os.makedirs(self.sdxl_dir, exist_ok=True)

        if args.iqa == '1':
            # Inception (pytorch_fid, 2048-d block)
            self.inception_model = InceptionV3([InceptionV3.BLOCK_INDEX_BY_DIM[2048]]).to(self.device)
            self.inception_model.eval()

            # VGG-16
            self.vgg_model = models.vgg16(pretrained=True).to(self.device)
            self.vgg_model.eval()

            # ResNet-50 with the classifier replaced by Identity
            self.resnet_model = models.resnet50(pretrained=True)
            self.resnet_model.fc = torch.nn.Identity()
            self.resnet_model = self.resnet_model.to(self.device)
            self.resnet_model.eval()

            # ViT (ImageNet 1K & 21K)
            self.vit_model_1k = ViTModel.from_pretrained('google/vit-base-patch16-224').to(self.device)
            self.vit_model_21k = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k').to(self.device)
            self.vit_preprocess_1k = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
            self.vit_preprocess_21k = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')

            # Swin Transformer (1K and in22k checkpoints)
            self.swin_model_1k = SwinModel.from_pretrained("microsoft/swin-base-patch4-window7-224").to(self.device)
            self.swin_model_21k = SwinModel.from_pretrained("microsoft/swin-base-patch4-window7-224-in22k").to(self.device)
            self.swin_preprocess_1k = AutoFeatureExtractor.from_pretrained("microsoft/swin-base-patch4-window7-224")
            self.swin_preprocess_21k = AutoFeatureExtractor.from_pretrained("microsoft/swin-base-patch4-window7-224-in22k")

            # Shared torchvision preprocessing (ImageNet mean/std) used for the
            # VGG, ResNet, MoCo and DINO backbones.
            self.base_preprocess = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ])

            # MoCo v2 (ResNet-50 backbone from a local checkpoint)
            self.moco_model = self.load_moco_model()

            # DINO. Nothing else here switches these hub models to eval mode, and
            # nn.Module defaults to training mode. In training mode the ResNet's
            # BatchNorm layers would use per-batch statistics (and update their
            # running stats) while extracting features.
            self.dino_model_resnet = torch.hub.load('facebookresearch/dino:main', 'dino_resnet50').to(self.device).eval()
            self.dino_model_vit = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16').to(self.device).eval()

            # CLIP models (ViT-B/16 and RN50)
            self.clip_model, self.clip_preprocess = clip.load("ViT-B/16", device=self.device)
            self.clip_resnet_model, self.clip_resnet_preprocess = clip.load("RN50", device=self.device)

            # MSN (Masked Siamese Networks)
            self.msn_processor = AutoFeatureExtractor.from_pretrained("facebook/vit-msn-base")
            self.msn_model = ViTMSNModel.from_pretrained("facebook/vit-msn-base").to(self.device)

        elif args.iqa == '3':
            # pyiqa TOPIQ no-reference metrics ('topiq_nr-face', 'topiq_nr-flive')
            self.iqa_metric_nr_face = pyiqa.create_metric('topiq_nr-face', device=self.device)
            self.iqa_metric_nr_flive = pyiqa.create_metric('topiq_nr-flive', device=self.device)
        elif args.iqa == '2':
            # VQA models used by extract_IQA_2.
            # NOTE: these use .cuda() rather than self.device because
            # extract_IQA_2 moves its inputs to "cuda". Change both together.
            self.blip_processor = AutoProcessor.from_pretrained("Salesforce/blip-vqa-base")
            self.blip_model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base").cuda()
            model_id = "google/paligemma-3b-mix-224"
            self.paligemma_model = PaliGemmaForConditionalGeneration.from_pretrained(model_id).cuda()
            self.paligemma_processor = AutoProcessor.from_pretrained(model_id)

    def load_moco_model(self, checkpoint_path='moco_v2_800ep_pretrain.pth.tar'):
        """Build a ResNet-50 feature extractor from a MoCo v2 pre-training checkpoint.

        Only the query-encoder backbone weights (``encoder_q.*`` with an
        optional ``module.`` prefix) are loaded. Its ``fc`` projection head and
        every other checkpoint entry are dropped. ``model.fc`` is set to
        Identity, so the model outputs pooled backbone features.

        Args:
            checkpoint_path: Path to the checkpoint; it must be a dict with a
                ``'state_dict'`` entry.

        Returns:
            The ResNet-50 on ``self.device`` in eval mode.

        Raises:
            RuntimeError: from ``load_state_dict`` if the filtered weights do
                not match the ResNet-50 backbone.
        """
        model = models.resnet50(pretrained=False)
        model.fc = torch.nn.Identity()
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        backbone_state = {}
        for key, value in checkpoint['state_dict'].items():
            if key.startswith('module.'):
                key = key[len('module.'):]
            # Keep only the query-encoder backbone. Other entries (e.g. a key
            # encoder ``encoder_k.*`` and queue buffers, which official MoCo
            # checkpoints presumably contain) would be rejected by the strict
            # load below as unexpected keys.
            if key.startswith('encoder_q.') and not key.startswith('encoder_q.fc'):
                backbone_state[key[len('encoder_q.'):]] = value

        model.load_state_dict(backbone_state)
        return model.to(self.device).eval()


    def save_sdxl_results(self, profession, gender, race, category,metric_name, raw_result):
        file_path = os.path.join(self.sdxl_dir, f'{profession}_{gender}_{race}_{category}_{gender}_{metric_name}.pt')
        torch.save(raw_result, file_path)
    def load_sdxl_results(self, profession, gender, race, category,metric_name='IQA1'):
        file_path = os.path.join(self.sdxl_dir, f'{profession}_{gender}_{race}_{category}_{gender}_{metric_name}.pt')
        if os.path.exists(file_path):
            return torch.load(file_path)
        else:
            return None
    def check_sdxl_results(self, profession, gender, race, category,metric_name):
        file_path = os.path.join(self.sdxl_dir, f'{profession}_{gender}_{race}_{category}_{gender}_{metric_name}.pt')
        if os.path.exists(file_path):
            return 0
        else:
            print(f"Extract {file_path}")
            return None
    def extract_sdxl_IQA1(self, image_paths, profession, gender, race, category):
        saved_result = self.check_sdxl_results(profession, gender, race, category, 'IQA1')
        if saved_result is not None:
            return
        else:
            print("MSN")
            act_msn = process_images(image_paths, model=self.msn_model, preprocess=self.msn_processor, device=self.device,msn=True)
            print("Inception")
            act_inception = process_images(image_paths, model=self.inception_model, preprocess=None, device=self.device, inception=True,batch_size=32)
            print("VGG")
            act_vgg = process_images(image_paths, model=self.vgg_model, preprocess=self.base_preprocess, device=self.device)
            print("ResNet")
            act_resnet = process_images(image_paths, model=self.resnet_model, preprocess=self.base_preprocess, device=self.device, resnet=True)
            print("ViT1K")
            act_vit_1k = process_images(image_paths, model=self.vit_model_1k, preprocess=self.vit_preprocess_1k, device=self.device, vit=True)
            print("ViT21K")
            act_vit_21k = process_images(image_paths, model=self.vit_model_21k, preprocess=self.vit_preprocess_21k, device=self.device, vit=True)
            print("Swin1K")
            act_swin_1k = process_images(image_paths, model=self.swin_model_1k, preprocess=self.swin_preprocess_1k, device=self.device,swin=True)
            print("Swin21K")
            act_swin_21k = process_images(image_paths, model=self.swin_model_21k, preprocess=self.swin_preprocess_21k, device=self.device,swin=True)
            print("MOCO")
            act_moco = process_images(image_paths, model=self.moco_model, preprocess=self.base_preprocess, device=self.device, moco=True)
            print("DINO ResNet")
            act_dino_resnet = process_images(image_paths, model=self.dino_model_resnet, preprocess=self.base_preprocess, device=self.device, dino=True)
            print("DINO ViT")
            act_dino_vit = process_images(image_paths, model=self.dino_model_vit, preprocess=self.base_preprocess, device=self.device, dino=True)
            print("CLIP ViT")
            act_clip_vit = process_images(image_paths, model=self.clip_model, preprocess=self.clip_preprocess, device=self.device, use_clip=True)
            print("CLIP ResNet")
            act_clip_resnet = process_images(image_paths, model=self.clip_resnet_model, preprocess=self.clip_resnet_preprocess, device=self.device, use_clip=True)
            # Save all variations clearly identified
            raw_result = {
                'inception': act_inception,
                'vgg': act_vgg,
                'resnet': act_resnet,
                'vit_1k': act_vit_1k,
                'vit_21k': act_vit_21k,
                'swin_1k': act_swin_1k,
                'swin_21k': act_swin_21k,
                'moco': act_moco,
                'dino_resnet': act_dino_resnet,
                'dino_vit': act_dino_vit,
                'clip_vit': act_clip_vit,
                'clip_resnet': act_clip_resnet,
                'msn': act_msn,
            }
            
            # Save the raw results for future use
            self.save_sdxl_results(profession, gender, race, category, 'IQA1', raw_result)
            return raw_result
    def compute_DQA(self,scaler, ref_A_embed, target_A_embed, ref_B_embed, target_B_embed,verbose=False):
        """Compute the DQA score for one reference/target pair of groups A and B.

        All four embedding sets are min-max scaled with a scaler fit on the
        pooled reference data (ref_A stacked with ref_B). Then

            DQA = |MMD(ref_A, target_A) - MMD(ref_B, target_B)|
                  / (MMD(ref, target) + 1e-3)

        where ``ref`` / ``target`` stack both groups. NaN or inf results are
        mapped to 0.

        Args:
            scaler: Ignored. A fresh MinMaxScaler fit on the reference data is
                used instead; the argument is kept for interface compatibility.
            ref_A_embed, target_A_embed, ref_B_embed, target_B_embed: 2-D
                embedding arrays (rows are samples).
            verbose: Unused.

        Returns:
            The DQA value, or 0 when it is NaN or infinite.
        """
        # Fit on the reference distribution only, then map every set into it.
        ref_scaler = MinMaxScaler()
        ref_scaler.fit(np.vstack((ref_A_embed, ref_B_embed)))
        ref_A_embed = ref_scaler.transform(ref_A_embed)
        ref_B_embed = ref_scaler.transform(ref_B_embed)
        target_A_embed = ref_scaler.transform(target_A_embed)
        target_B_embed = ref_scaler.transform(target_B_embed)
        # Min-max scaling is row-wise, so stacking the scaled parts equals
        # scaling the stacked data.
        ref_data_embed = np.vstack((ref_A_embed, ref_B_embed))
        target_data_embed = np.vstack((target_A_embed, target_B_embed))

        a_gap = compute_mmd(ref_A_embed, target_A_embed)
        b_gap = compute_mmd(ref_B_embed, target_B_embed)
        group_gap = abs(a_gap - b_gap)
        overall_shift = compute_mmd(ref_data_embed, target_data_embed)

        # The epsilon stabilizes the denominator, as in dqa_iqa2 / dqa_iqa3.
        # ``term1/term3+1e-3`` added it to the quotient instead.
        DQA = abs(group_gap / (overall_shift + 1e-3))
        if np.isnan(DQA) or np.isinf(DQA):
            DQA = 0
        return DQA

    
    def compute_DQA_trapping(self, scaler, ref_A_embed, target_A_embed, ref_B_embed, target_B_embed, n_iterations=10, random_split=False, verbose=False,cat1=False):
        """Average ``compute_DQA`` over ``n_iterations`` bootstrap resamples.

        Each iteration resamples four sets with replacement, keeping each set's
        size. With ``cat1=False`` those are the four given sets. With
        ``cat1=True`` the target arguments are ignored: ref_A and ref_B are each
        permuted and split into two halves (first ``len // 2`` rows as
        reference, the rest as target), which are then resampled.

        Args:
            scaler: Forwarded to ``compute_DQA``.
            ref_A_embed, target_A_embed, ref_B_embed, target_B_embed: NumPy
                embedding arrays (indexed with integer arrays).
            n_iterations: Number of bootstrap iterations.
            random_split: Unused.
            verbose: Forwarded to ``compute_DQA``.
            cat1: Use the split-half self-comparison described above.

        Returns:
            Mean of the per-iteration DQA values.
        """
        def split_in_half(embed):
            # One random permutation per call, in A-then-B order.
            order = np.random.permutation(len(embed))
            cut = len(embed) // 2
            return embed[order[:cut]], embed[order[cut:]]

        scores = []
        for _ in tqdm(range(n_iterations)):
            if not cat1:
                pool_ref_A, pool_target_A = ref_A_embed, target_A_embed
                pool_ref_B, pool_target_B = ref_B_embed, target_B_embed
            else:
                pool_ref_A, pool_target_A = split_in_half(ref_A_embed)
                pool_ref_B, pool_target_B = split_in_half(ref_B_embed)

            # Resampling order (ref_A, ref_B, target_A, target_B) fixes the RNG stream.
            boot_ref_A, boot_ref_B, boot_target_A, boot_target_B = (
                resample(pool, replace=True, n_samples=len(pool))
                for pool in (pool_ref_A, pool_ref_B, pool_target_A, pool_target_B)
            )
            scores.append(self.compute_DQA(scaler, boot_ref_A, boot_target_A,
                                           boot_ref_B, boot_target_B, verbose=verbose))

        return np.mean(np.array(scores))
    def run_sdxl_DQA(self,profession,model_name, race_list,mode):
        """Compute bootstrapped embedding DQA scores for categories 1-6 of one profession.

        Category 1 is the reference for every category. Embeddings come from the
        IQA1 cache (``load_sdxl_results`` with its default metric). Global
        RNGs are re-seeded on every call.

        Args:
            profession: Profession whose cached embeddings are compared.
            model_name: Backbone key inside the cached IQA1 dict.
            race_list: Races pooled together (gender mode) or compared pairwise
                (race mode).
            mode: 'gender' compares male vs female. 'race' averages over every
                pair of races. Any other value returns None.

        Returns:
            Tuple (DQA1, ..., DQA6).

        Raises:
            FileNotFoundError: if a required cached IQA1 result is missing.
        """
        seed = 0
        random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        np.random.seed(seed)
        torch.backends.cudnn.deterministic = True
        # The flag is ``benchmark``. Assigning ``benchmarks`` only created an unused attribute.
        torch.backends.cudnn.benchmark = False
        # NOTE: setting PYTHONHASHSEED at runtime does not change this interpreter's
        # hash randomization; only child processes inherit it.
        os.environ['PYTHONHASHSEED'] = str(seed)

        categories = range(1, 7)

        def load_embed(gender, race, cat):
            # ``model_name`` embeddings for one (gender, race, category) group.
            result = self.load_sdxl_results(profession, gender, race, cat)
            if result is None:
                raise FileNotFoundError(
                    f"No cached IQA1 results for {profession}/{gender}/{race}/category {cat}")
            return result[model_name]

        if mode == 'gender':
            # Pool all races: one stacked array per (category, gender).
            male = {cat: np.vstack([load_embed('male', race, cat) for race in race_list])
                    for cat in categories}
            female = {cat: np.vstack([load_embed('female', race, cat) for race in race_list])
                      for cat in categories}

            scaler = MinMaxScaler()
            scaler.fit(np.vstack([embed for cat in categories for embed in (male[cat], female[cat])]))

            # Calls run in category order; each draws from the global NumPy RNG.
            # NOTE(review): race mode passes cat1=True for DQA1 (split-half
            # self-comparison) but gender mode does not. Confirm which is intended.
            return tuple(
                self.compute_DQA_trapping(scaler, male[1], male[cat], female[1], female[cat])
                for cat in categories
            )

        elif mode == 'race':
            # Per race: one array per category with male and female rows stacked.
            per_race = {
                race: {cat: np.vstack([load_embed('male', race, cat), load_embed('female', race, cat)])
                       for cat in categories}
                for race in race_list
            }

            scaler = MinMaxScaler()
            scaler.fit(np.vstack([per_race[race][cat] for race in race_list for cat in categories]))

            scores = {cat: [] for cat in categories}
            # Call order (pair by pair, category 1..6) matters: each call draws
            # from the global NumPy RNG.
            for race_a, race_b in itertools.combinations(race_list, 2):
                embeds_a, embeds_b = per_race[race_a], per_race[race_b]
                for cat in categories:
                    scores[cat].append(self.compute_DQA_trapping(
                        scaler, embeds_a[1], embeds_a[cat], embeds_b[1], embeds_b[cat],
                        cat1=(cat == 1)))

            # Average over race pairs
            return tuple(np.mean(np.array(scores[cat])) for cat in categories)
            
        
    def extract_IQA_2(self,image_paths, profession, gender,race, category):
        saved_result = self.check_sdxl_results(profession, gender, race, category, 'IQA2')
        if saved_result is not None:
            return
        else : 
            prompt1 = "Is this image real or fake?"
            prompt2 = "Are this person’s limbs distorted?"
            blip_IQA2_result_1=[]
            blip_IQA2_result_2=[]
            paligemma_IQA2_result_1=[]
            paligemma_IQA2_result_2=[]
            
            for image_path in tqdm(image_paths):
                image = Image.open(image_path).convert('RGB')

                
                inputs = self.blip_processor(image, prompt1, return_tensors="pt").to("cuda")
                out = self.blip_model.generate(**inputs)
                blip_decoded1 = self.blip_processor.decode(out[0], skip_special_tokens=True)

                inputs = self.blip_processor(image, prompt2, return_tensors="pt").to("cuda")
                out = self.blip_model.generate(**inputs)
                blip_decoded2 = self.blip_processor.decode(out[0], skip_special_tokens=True)


                blip_IQA2_result_1.append(blip_decoded1)
                blip_IQA2_result_2.append(blip_decoded2)
            
                model_inputs = self.paligemma_processor(text=prompt1, images=image, return_tensors="pt").to("cuda")
                input_len = model_inputs["input_ids"].shape[-1]
                with torch.inference_mode():
                    generation = self.paligemma_model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
                    generation = generation[0][input_len:]
                    paligemma_decoded1 = self.paligemma_processor.decode(generation, skip_special_tokens=True)

                model_inputs = self.paligemma_processor(text=prompt2, images=image, return_tensors="pt").to("cuda")
                input_len = model_inputs["input_ids"].shape[-1]
                with torch.inference_mode():
                    generation = self.paligemma_model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
                    generation = generation[0][input_len:]
                    paligemma_decoded2 = self.paligemma_processor.decode(generation, skip_special_tokens=True)
                
                paligemma_IQA2_result_1.append(paligemma_decoded1)
                paligemma_IQA2_result_2.append(paligemma_decoded2)
            raw_result={
                'blip_prompt1': blip_IQA2_result_1,
                'blip_prompt2':blip_IQA2_result_2,
                'paligemma_prompt1':paligemma_IQA2_result_1,
                'paligemma_prompt2':paligemma_IQA2_result_2}
            self.save_sdxl_results(profession, gender, race, category, 'IQA2', raw_result)
            return raw_result
        
    def extract_IQA_3(self,image_paths, profession, gender,race, category):
        saved_result = self.check_sdxl_results(profession, gender, race, category, 'IQA3')
        if saved_result is not None:
            return
        else : 
        
            IQA3_faces=[]
            IQA3_flive=[]
            for image_path in tqdm(image_paths):
                try : 
                    score_face = self.iqa_metric_nr_face(image_path)[0][0].item()    
                except AssertionError:
                    score_face=0
                score_flive = self.iqa_metric_nr_flive(image_path)[0][0].item()
                IQA3_faces.append(score_face)
                IQA3_flive.append(score_flive)
            IQA3_faces = np.array(IQA3_faces)
            IQA3_flive = np.array(IQA3_flive)
            raw_result = {
                'IQA3_faces': IQA3_faces,
                'IQA3_flive': IQA3_flive,
            }

            self.save_sdxl_results(profession, gender, race, category, 'IQA3', raw_result)
            return raw_result
        
    def DQA_IQA2_with_bootstrap(self, profession, model_name, race_list, mode, n_iterations=100):
        """Bootstrap the IQA2 (VQA answer-rate) DQA score for categories 1-6.

        A group's score is the fraction of its cached answers equal to the
        expected "good" answer for the prompt. In each bootstrap iteration,
        every cached (race, category) answer list is resampled with replacement.
        Then

            DQA = |s_a - s_b| / (|(ref_a + ref_b) - (s_a + s_b)| + 1e-3).

        Args:
            profession: Profession whose cached IQA2 answers are used.
            model_name: One of 'blip_prompt1', 'blip_prompt2',
                'paligemma_prompt1', 'paligemma_prompt2'.
            race_list: Races pooled together (gender mode) or compared pairwise
                (race mode).
            mode: 'gender' (male vs female, category 1 as reference) or 'race'.
            n_iterations: Number of bootstrap iterations.

        Returns:
            Tuple of the mean bootstrap DQA for categories 1..6.

        Raises:
            FileNotFoundError: if a required IQA2 cache file is missing.
        """
        # The expected answer depends on the prompt, not the VQA model.
        # prompt1 asks "real or fake?" and prompt2 asks "limbs distorted?".
        # Selecting by model gave blip_prompt2 'real' and paligemma_prompt1
        # 'no', so those rates were always 0.
        answer = 'real' if model_name.endswith('prompt1') else 'no'

        def dqa_iqa2(male, female, ref_male, ref_female, prompt_answer):
            male_score = Counter(male)[prompt_answer] / len(male)
            female_score = Counter(female)[prompt_answer] / len(female)
            ref_male_score = Counter(ref_male)[prompt_answer] / len(ref_male)
            ref_female_score = Counter(ref_female)[prompt_answer] / len(ref_female)
            denom = abs((ref_male_score + ref_female_score) - (male_score + female_score))
            num = abs(male_score - female_score)
            return num / (denom + 1e-3)

        categories = range(1, 7)

        # The cached answers do not change between iterations, so read them once.
        cached = {}
        if mode in ('gender', 'race'):
            for race in race_list:
                for cat in categories:
                    per_gender = []
                    for gender in ('male', 'female'):
                        result = self.load_sdxl_results(profession, gender, race, cat, metric_name='IQA2')
                        if result is None:
                            raise FileNotFoundError(
                                f"No cached IQA2 results for {profession}/{gender}/{race}/category {cat}")
                        per_gender.append(result[model_name])
                    cached[(race, cat)] = tuple(per_gender)

        DQA_bootstrap = {f"DQA{cat}": [] for cat in categories}

        for _ in range(n_iterations):
            if mode == 'gender':
                pooled_male = {cat: [] for cat in categories}
                pooled_female = {cat: [] for cat in categories}
                # Resample in (race, category, male, female) order.
                for race in race_list:
                    for cat in categories:
                        male_answers, female_answers = cached[(race, cat)]
                        pooled_male[cat].extend(resample(male_answers, replace=True))
                        pooled_female[cat].extend(resample(female_answers, replace=True))

                # Category 1 is the reference for every category, including itself.
                for cat in categories:
                    DQA_bootstrap[f"DQA{cat}"].append(
                        dqa_iqa2(pooled_male[cat], pooled_female[cat],
                                 pooled_male[1], pooled_female[1], answer))

            elif mode == 'race':
                pooled = {race: {cat: [] for cat in categories} for race in race_list}
                for race in race_list:
                    for cat in categories:
                        male_answers, female_answers = cached[(race, cat)]
                        pooled[race][cat].extend(resample(male_answers + female_answers, replace=True))

                for cat in categories:
                    for race1, race2 in itertools.combinations(race_list, 2):
                        cat_race1 = pooled[race1][cat]
                        cat_race2 = pooled[race2][cat]
                        # NOTE(review): the compared sets are their own reference,
                        # so the denominator is always 0 and the score reduces
                        # to |s_a - s_b| / 1e-3. Gender mode uses category 1 as
                        # the reference instead. Confirm this is intended.
                        DQA_bootstrap[f"DQA{cat}"].append(
                            dqa_iqa2(cat_race1, cat_race2, cat_race1, cat_race2, answer))

        return tuple(np.mean(DQA_bootstrap[f"DQA{cat}"]) for cat in categories)
    
    def DQA_IQA3_with_bootstrap(self, profession, model_name, race_list, mode, n_iterations=100):
        """Return bootstrap-averaged IQA3 DQA scores for categories 1-6.

        Each bootstrap iteration resamples (with replacement) the scores in
        ``self.load_sdxl_results(..., metric_name='IQA3')[model_name]`` and
        computes one DQA value per category (one per race pair in ``'race'``
        mode). The returned value for a category is the mean of all DQA values
        collected for it across iterations.

        DQA for groups ``a``/``b`` with references ``ref_a``/``ref_b`` is::

            |mean(a) - mean(b)| /
                (|(mean(ref_a) + mean(ref_b)) - (mean(a) + mean(b))| + 1e-3)

        When the reference is the compared groups themselves, the denominator
        is exactly 0 and the score reduces to ``|mean(a) - mean(b)| / 1e-3``.

        Args:
            profession: Profession passed through to ``load_sdxl_results``.
            model_name: Key selecting the score sequence inside each loaded
                result.
            race_list: Races whose results are loaded.
            mode: ``'gender'``: male vs. female scores pooled over
                ``race_list``. Category 1 is its own reference, and categories
                2-6 use category 1 as reference.
                ``'race'``: every unordered pair of races, with male and female
                scores pooled. Each gender is truncated to the shorter of the
                two lists before resampling. Each pair is its own reference.
            n_iterations: Number of bootstrap iterations.

        Returns:
            Tuple ``(DQA1, ..., DQA6)``. An entry is NaN (numpy mean of an
            empty list) when nothing was collected for it, e.g. when
            ``n_iterations == 0`` or, in race mode, ``race_list`` has fewer
            than two races.

        Raises:
            ValueError: If ``mode`` is neither ``'gender'`` nor ``'race'``.
        """
        if mode not in ('gender', 'race'):
            # Fail fast: an unknown mode would otherwise collect nothing and
            # silently return a tuple of NaNs.
            raise ValueError(f"mode must be 'gender' or 'race', got {mode!r}")

        cats = range(1, 7)

        def dqa_iqa3(male, female, ref_male, ref_female):
            male_score = np.mean(male)
            female_score = np.mean(female)
            ref_male_score = np.mean(ref_male)
            ref_female_score = np.mean(ref_female)
            denom = abs((ref_male_score + ref_female_score) - (male_score + female_score))
            num = abs(male_score - female_score)
            # 1e-3 keeps the ratio finite when denom is 0 (reference == groups).
            return num / (denom + 1e-3)

        # Load every (race, category) result once rather than once per
        # bootstrap iteration; only the resampling needs to be repeated.
        # NOTE(review): assumes load_sdxl_results returns the same data on
        # every call and has no side effects the bootstrap relies on; confirm.
        loaded_scores = {}
        for race in race_list:
            for cat in cats:
                male_result = self.load_sdxl_results(profession, 'male', race, cat, metric_name='IQA3')
                female_result = self.load_sdxl_results(profession, 'female', race, cat, metric_name='IQA3')
                loaded_scores[race, cat] = (male_result[model_name], female_result[model_name])

        DQA_bootstrap = {f"DQA{cat}": [] for cat in cats}

        for _ in range(n_iterations):
            if mode == 'gender':
                # Per-category pools of resampled scores across all races.
                pooled_male = {cat: [] for cat in cats}
                pooled_female = {cat: [] for cat in cats}
                for race in race_list:
                    for cat in cats:
                        male_scores, female_scores = loaded_scores[race, cat]
                        pooled_male[cat].extend(resample(male_scores, replace=True))
                        pooled_female[cat].extend(resample(female_scores, replace=True))

                # Category 1 is the reference for every category; for cat1
                # itself this makes the reference equal the compared groups.
                ref_male = np.array(pooled_male[1])
                ref_female = np.array(pooled_female[1])
                for cat in cats:
                    male = np.array(pooled_male[cat])
                    female = np.array(pooled_female[cat])
                    DQA_bootstrap[f"DQA{cat}"].append(dqa_iqa3(male, female, ref_male, ref_female))

            else:  # mode == 'race'
                pooled = {(race, cat): [] for race in race_list for cat in cats}
                for race in race_list:
                    for cat in cats:
                        male_scores, female_scores = loaded_scores[race, cat]
                        # Truncate both genders to the same length so each
                        # contributes the same number of samples to the pool.
                        min_len = min(len(male_scores), len(female_scores))
                        male_resampled = resample(male_scores[:min_len], replace=True)
                        female_resampled = resample(female_scores[:min_len], replace=True)
                        pooled[race, cat].extend(np.concatenate([male_resampled, female_resampled]))

                for cat in cats:
                    for race1, race2 in itertools.combinations(race_list, 2):
                        scores1 = np.array(pooled[race1, cat])
                        scores2 = np.array(pooled[race2, cat])
                        # Each pair is its own reference, so the denominator is 0.
                        DQA_bootstrap[f"DQA{cat}"].append(dqa_iqa3(scores1, scores2, scores1, scores2))

        return tuple(np.mean(DQA_bootstrap[f"DQA{cat}"]) for cat in cats)
    